/*
 * vim:syntax=c
 * A small module to implement missing C99 math capabilities required by numpy
 *
 * Please keep this independent of python ! Only basic types (npy_longdouble)
 * can be used, otherwise, pure C, without any use of Python facilities
 *
 * How to add a function to this section
 * -------------------------------------
 *
 * Say you want to add `foo`, these are the steps and the reasons for them.
 *
 * 1) Add foo to the appropriate list in the configuration system. The
 *    lists can be found in numpy/_core/setup.py lines 63-105. Read the
 *    comments that come with them, they are very helpful.
 *
 * 2) The configuration system will define a macro HAVE_FOO if your function
 *    can be linked from the math library. The result can depend on the
 *    optimization flags as well as the compiler, so can't be known ahead of
 *    time. If the function can't be linked, then either it is absent, defined
 *    as a macro, or is an intrinsic (hardware) function.
 *
 *    i) Undefine any possible macros:
 *
 *    #ifdef foo
 *    #undef foo
 *    #endif
 *
 *    ii) Avoid declaring functions here as much as possible. Declaring
 *    functions is not portable: for example, some platforms define some
 *    functions inline with a non-standard identifier, or add another
 *    identifier that changes the calling convention of the function. If you
 *    really have to, ALWAYS declare it only for the one platform you are
 *    dealing with:
 *
 *    Not ok:
 *        double exp(double a);
 *
 *    Ok:
 *        #ifdef SYMBOL_DEFINED_WEIRD_PLATFORM
 *        double exp(double);
 *        #endif
 *
 * Some of the code is taken from msun library in FreeBSD, with the following
 * notice:
 *
 * ====================================================
 * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
 *
 * Developed at SunPro, a Sun Microsystems, Inc. business.
 * Permission to use, copy, modify, and distribute this
 * software is freely granted, provided that this notice
 * is preserved.
 * ====================================================
 */
#include "npy_math_private.h"
#ifdef _MSC_VER
#  include <intrin.h>   // for __popcnt
#endif

/* Magic binary numbers used by bit_count
 * For type T, the magic numbers are computed as follows:
 * Magic[0]: 01 01 01 01 01 01... = (T)~(T)0/3
 * Magic[1]: 0011 0011 0011...    = (T)~(T)0/15  * 3
 * Magic[2]: 00001111 00001111... = (T)~(T)0/255 * 15
 * Magic[3]: 00000001 00000001... = (T)~(T)0/255
 *
 * Counting bits set, in parallel
 * Based on: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
 *
 * Generic Algorithm for type T:
 * a = a - ((a >> 1) & (T)~(T)0/3);
 * a = (a & (T)~(T)0/15*3) + ((a >> 2) & (T)~(T)0/15*3);
 * a = (a + (a >> 4)) & (T)~(T)0/255*15;
 * c = (T)(a * ((T)~(T)0/255)) >> (sizeof(T) - 1) * CHAR_BIT;
*/

/*
 * Bit-count masks for 8-, 16-, 32- and 64-bit operands. Entry i of each table
 * is Magic[i] from the description above, written out for that width.
 */
static const npy_uint8  MAGIC8[]  = {0x55u,                 0x33u,                 0x0Fu,                 0x01u};
static const npy_uint16 MAGIC16[] = {0x5555u,               0x3333u,               0x0F0Fu,               0x0101u};
static const npy_uint32 MAGIC32[] = {0x55555555ul,          0x33333333ul,          0x0F0F0F0Ful,          0x01010101ul};
static const npy_uint64 MAGIC64[] = {0x5555555555555555ull, 0x3333333333333333ull, 0x0F0F0F0F0F0F0F0Full, 0x0101010101010101ull};


/*
 *****************************************************************************
 **                    BLOCKLIST-ABLE BASIC MATH FUNCTIONS                  **
 *****************************************************************************
 */

/* The following functions can be blocked, even for doubles */

/* Original code by Konrad Hinsen.  */
/* Taken from FreeBSD mlib, adapted for numpy
 *
 * XXX: we could be a bit faster by reusing high/low words for inf/nan
 * classification instead of calling npy_isinf/npy_isnan: we should have some
 * macros for this, though, instead of doing it manually
 */
/*
 * Base-2 logarithm of x.
 *
 * Uses the C library log2 unless NPY_BLOCK_LOG2 is defined, in which case a
 * frexp/log based fallback is used.
 */
NPY_INPLACE double npy_log2(double x)
{
#ifndef NPY_BLOCK_LOG2
    return log2(x);
#else
    int exponent;
    double mantissa;

    /* non-finite, zero and negative inputs: npy_log supplies the special value */
    if (!npy_isfinite(x) || x <= 0.) {
        return npy_log(x);
    }

    /*
     * fallback implementation copied from python3.4 math.log2
     * provides int(log(2**i)) == i for i 1-64 in default rounding mode.
     *
     * We want log2(m * 2**e) == log(m) / log(2) + e.  Care is needed when
     * x is just greater than 1.0: in that case e is 1, log(m) is negative,
     * and we get significant cancellation error from the addition of
     * log(m) / log(2) to e.  Folding a factor of two into the mantissa for
     * x >= 1 avoids this problem.
     */
    mantissa = frexp(x, &exponent);
    if (x >= 1.0) {
        return log(2.0 * mantissa) / log(2.0) + (exponent - 1);
    }
    return log(mantissa) / log(2.0) + exponent;
#endif
}

/* Taken from FreeBSD mlib, adapted for numpy
 *
 * XXX: we could be a bit faster by reusing high/low words for inf/nan
 * classification instead of calling npy_isinf/npy_isnan: we should have some
 * macros for this, though, instead of doing it manually
 */
/*
 * XXX: we should have this in npy_math.h
 *
 * NOTE: despite its name, this value is not DBL_EPSILON (about 2.22e-16).
 * It equals the `pi_lo` constant of FreeBSD's e_atan2.c, i.e. the low-order
 * part of pi that is lost when pi is rounded to double. It is used below only
 * as that correction term.
 */
#define NPY_DBL_EPSILON 1.2246467991473531772E-16

/*
 * Arc tangent of y/x, using the signs of both arguments to pick the quadrant.
 *
 * Uses the C library atan2 unless NPY_BLOCK_ATAN2 is defined, in which case
 * the fallback below handles the special values (NaN, zeros, infinities)
 * explicitly and otherwise evaluates npy_atan(|y/x|) with a quadrant fix-up.
 *
 * y: ordinate, x: abscissa. Returns the angle in radians.
 */
NPY_INPLACE double npy_atan2(double y, double x)
{
#ifndef NPY_BLOCK_ATAN2
    return atan2(y, x);
#else
    npy_int32 k, m, iy, ix, hx, hy;
    npy_uint32 lx,ly;
    double z;

    /* hx/hy carry sign and exponent; ix/iy are the same with the sign cleared */
    EXTRACT_WORDS(hx, lx, x);
    ix = hx & 0x7fffffff;
    EXTRACT_WORDS(hy, ly, y);
    iy = hy & 0x7fffffff;

    /*
     * If x or y is NaN, return NaN. Each operand is tested on its own:
     * testing the product x * y would also trigger for 0 * inf, whose atan2
     * is well defined (+-pi/2, +-0 or +-pi) and is handled further down.
     */
    if (npy_isnan(x) || npy_isnan(y)) {
        return x + y;
    }

    if (x == 1.0) {
        return npy_atan(y);
    }

    /* m encodes the quadrant: bit 1 = sign of x, bit 0 = sign of y */
    m = 2 * (npy_signbit((x)) != 0) + (npy_signbit((y)) != 0);
    if (y == 0.0) {
        switch(m) {
        case 0:
        case 1: return  y;  /* atan(+-0,+anything)=+-0 */
        case 2: return  NPY_PI;/* atan(+0,-anything) = pi */
        case 3: return -NPY_PI;/* atan(-0,-anything) =-pi */
        }
    }

    /* y is non-zero here, so only its sign matters */
    if (x == 0.0) {
        return y > 0 ? NPY_PI_2 : -NPY_PI_2;
    }

    if (npy_isinf(x)) {
        if (npy_isinf(y)) {
            switch(m) {
                case 0: return  NPY_PI_4;/* atan(+INF,+INF) */
                case 1: return -NPY_PI_4;/* atan(-INF,+INF) */
                case 2: return  3.0*NPY_PI_4;/*atan(+INF,-INF)*/
                case 3: return -3.0*NPY_PI_4;/*atan(-INF,-INF)*/
            }
        } else {
            switch(m) {
                case 0: return  NPY_PZERO;  /* atan(+...,+INF) */
                case 1: return  NPY_NZERO;  /* atan(-...,+INF) */
                case 2: return  NPY_PI;  /* atan(+...,-INF) */
                case 3: return -NPY_PI;  /* atan(-...,-INF) */
            }
        }
    }

    /* x is finite and non-zero here */
    if (npy_isinf(y)) {
        return y > 0 ? NPY_PI_2 : -NPY_PI_2;
    }

    /* compute y/x; k is the difference of the binary exponents of y and x */
    k = (iy - ix) >> 20;
    if (k > 60) {            /* |y/x| >  2**60 */
        z = NPY_PI_2 + 0.5 * NPY_DBL_EPSILON;
        m &= 1;
    } else if (hx < 0 && k < -60) {
        z = 0.0;    /* 0 > |y|/x > -2**-60 */
    } else {
        z = npy_atan(npy_fabs(y/x));        /* safe to do y/x */
    }

    /* place the first-quadrant angle z into the quadrant selected by m */
    switch (m) {
        case 0: return  z  ;    /* atan(+,+) */
        case 1: return -z  ;    /* atan(-,+) */
        case 2: return  NPY_PI - (z - NPY_DBL_EPSILON);/* atan(+,-) */
        default: /* case 3 */
            return  (z - NPY_DBL_EPSILON) - NPY_PI;/* atan(-,-) */
    }
#endif
}





/*
 * Euclidean length sqrt(x*x + y*y).
 *
 * Uses the C library hypot unless NPY_BLOCK_HYPOT is defined. The fallback
 * scales by the larger magnitude so that the squares are never formed
 * directly.
 */
NPY_INPLACE double npy_hypot(double x, double y)
{
#ifndef NPY_BLOCK_HYPOT
    return hypot(x, y);
#else
    double larger, smaller, ratio;

    /* the infinity test comes first, so an infinite argument wins over NaN */
    if (npy_isinf(x) || npy_isinf(y)) {
        return NPY_INFINITY;
    }
    if (npy_isnan(x) || npy_isnan(y)) {
        return NPY_NAN;
    }

    /* order the magnitudes so that larger >= smaller */
    larger = npy_fabs(x);
    smaller = npy_fabs(y);
    if (larger < smaller) {
        const double held = larger;
        larger = smaller;
        smaller = held;
    }

    /* both arguments are zero: avoid 0/0 below */
    if (larger == 0.) {
        return 0.;
    }

    ratio = smaller/larger;
    return larger*npy_sqrt(1.+ratio*ratio);
#endif
}

/*
 *
 * sin, cos, tan
 * sinh, cosh, tanh,
 * fabs, floor, ceil, rint, trunc
 * sqrt, log10, log, exp, expm1
 * asin, acos, atan,
 * asinh, acosh, atanh
 *
 * hypot, atan2, pow, fmod, modf
 * ldexp, frexp, cbrt
 *
 * We assume the above are always available in their double versions.
 *
 * NOTE: some facilities may be available only as macros instead of functions.
 * For simplicity, we define our own functions and undef the macros. We could
 * instead test for the macro, but I am too lazy to do that for now.
 */


/*
 * Decorate all the math functions which are available on the current platform
 */

/**begin repeat
 * #type = npy_longdouble, npy_double, npy_float#
 * #TYPE = LONGDOUBLE, DOUBLE, FLOAT#
 * #c = l,,f#
 * #C = L,,F#
 */
/*
 * NPY__FP_SFX(X) names the C library function to call for the current type:
 * X unchanged when the type has the same size as double, otherwise X joined
 * with the type suffix through NPY_CAT.
 */
#undef NPY__FP_SFX
#if NPY_SIZEOF_@TYPE@ == NPY_SIZEOF_DOUBLE
    #define NPY__FP_SFX(X) X
#else
    #define NPY__FP_SFX(X) NPY_CAT(X, @c@)
#endif
/*
 * On arm64 macOS, there's a bug with sin, cos, and tan where they don't
 * raise "invalid" when given INFINITY as input.
 *
 * WORKAROUND_APPLE_TRIG_BUG is 1 on that platform and 0 everywhere else; it
 * is consumed by the trig wrappers that follow.
 */
#if defined(__APPLE__) && defined(__arm64__)
#define WORKAROUND_APPLE_TRIG_BUG 1
#else
#define WORKAROUND_APPLE_TRIG_BUG 0
#endif

/**begin repeat1
 * #kind = sin,cos,tan#
 * #TRIG_WORKAROUND = WORKAROUND_APPLE_TRIG_BUG*3#
 */
/*
 * npy_@kind@@c@: forwards to the C library @kind@ selected by NPY__FP_SFX.
 *
 * When the Apple arm64 workaround is enabled, non-finite input bypasses the
 * library call and returns x - x instead.
 */
NPY_INPLACE @type@ npy_@kind@@c@(@type@ x)
{
#if @TRIG_WORKAROUND@
    if (!npy_isfinite(x)) {
        return (x - x);
    }
#endif
    return NPY__FP_SFX(@kind@)(x);
}

/**end repeat1**/

#undef WORKAROUND_APPLE_TRIG_BUG

/**end repeat**/

/* Blocklist-able C99 functions */

/**begin repeat
 * #type = npy_float,npy_longdouble#
 * #TYPE = FLOAT,LONGDOUBLE#
 * #c = f,l#
 * #C = F,L#
 */
/*
 * NPY__FP_SFX(X): X unchanged when the current type has the same size as
 * double, otherwise X joined with the type suffix through NPY_CAT.
 */
#undef NPY__FP_SFX
#if NPY_SIZEOF_@TYPE@ == NPY_SIZEOF_DOUBLE
    #define NPY__FP_SFX(X) X
#else
    #define NPY__FP_SFX(X) NPY_CAT(X, @c@)
#endif

/**begin repeat1
 * #kind = exp,log2,sqrt#
 * #KIND = EXP,LOG2,SQRT#
 */

#ifdef @kind@@c@
#undef @kind@@c@
#endif
/*
 * npy_@kind@@c@: one-argument wrapper.
 *
 * If the function is blocklisted for this type, the argument is converted to
 * double, the double-precision npy_@kind@ is used, and the result is cast
 * back. Otherwise the C library routine chosen by NPY__FP_SFX is called.
 */
NPY_INPLACE @type@ npy_@kind@@c@(@type@ x)
{
#ifdef NPY_BLOCK_@KIND@@C@
    return (@type@) npy_@kind@((double)x);
#else
    return NPY__FP_SFX(@kind@)(x);
#endif
}

/**end repeat1**/


/**begin repeat1
 * #kind = atan2,hypot,pow#
 * #KIND = ATAN2,HYPOT,POW#
 */
#ifdef @kind@@c@
#undef @kind@@c@
#endif
/*
 * npy_@kind@@c@: two-argument wrapper.
 *
 * If the function is blocklisted for this type, both arguments are converted
 * to double, the double-precision npy_@kind@ is used, and the result is cast
 * back. Otherwise the C library routine chosen by NPY__FP_SFX is called.
 */
NPY_INPLACE @type@ npy_@kind@@c@(@type@ x, @type@ y)
{
#ifdef NPY_BLOCK_@KIND@@C@
    return (@type@) npy_@kind@((double)x, (double) y);
#else
    return NPY__FP_SFX(@kind@)(x, y);
#endif
}
/**end repeat1**/

#ifdef modf@c@
#undef modf@c@
#endif
/*
 * npy_modf@c@: split x into fractional part (returned) and integral part
 * (stored through iptr).
 *
 * If modf is blocklisted for this type, the work is done by the
 * double-precision npy_modf and both parts are cast back. Otherwise the C
 * library routine chosen by NPY__FP_SFX is called.
 */
NPY_INPLACE @type@ npy_modf@c@(@type@ x, @type@ *iptr)
{
#ifdef NPY_BLOCK_MODF@C@
    double whole;
    const double frac = npy_modf((double)x, &whole);

    *iptr = (@type@) whole;
    return (@type@) frac;
#else
    return NPY__FP_SFX(modf)(x, iptr);
#endif
}


/**end repeat**/


#undef NPY__FP_SFX


/*
 * Non standard functions
 */

/**begin repeat
 * #type = npy_float, npy_double, npy_longdouble#
 * #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
 * #c = f, ,l#
 * #C = F, ,L#
 */
/*
 * NPY__FP_SFX(X): X unchanged when the current type has the same size as
 * double, otherwise X joined with the type suffix through NPY_CAT. Below it
 * is applied to constants as well as to function names.
 */
#undef NPY__FP_SFX
#if NPY_SIZEOF_@TYPE@ == NPY_SIZEOF_DOUBLE
    #define NPY__FP_SFX(X) X
#else
    #define NPY__FP_SFX(X) NPY_CAT(X, @c@)
#endif
/*
 * Heaviside step function.
 *
 * Returns NaN for NaN input, h0 when x is zero, 0 for negative x and 1 for
 * positive x.
 */
@type@ npy_heaviside@c@(@type@ x, @type@ h0)
{
    if (npy_isnan(x)) {
        return (@type@) NPY_NAN;
    }
    if (x == 0) {
        /* the value at the step is chosen by the caller */
        return h0;
    }
    return (x < 0) ? (@type@) 0.0 : (@type@) 1.0;
}

/*
 * Constants for the current type, built with NPY__FP_SFX from NPY_LOGE2,
 * NPY_LOG2E, NPY_PI and the literal 180.0. RAD2DEG and DEG2RAD are the
 * conversion factors 180/pi and pi/180.
 */
#define LOGE2    NPY__FP_SFX(NPY_LOGE2)
#define LOG2E    NPY__FP_SFX(NPY_LOG2E)
#define RAD2DEG  (NPY__FP_SFX(180.0)/NPY__FP_SFX(NPY_PI))
#define DEG2RAD  (NPY__FP_SFX(NPY_PI)/NPY__FP_SFX(180.0))

/* Convert an angle from radians to degrees: x * (180 / pi). */
NPY_INPLACE @type@ npy_rad2deg@c@(@type@ x)
{
    return x*RAD2DEG;
}

/* Convert an angle from degrees to radians: x * (pi / 180). */
NPY_INPLACE @type@ npy_deg2rad@c@(@type@ x)
{
    return x*DEG2RAD;
}

/*
 * Returns LOG2E * log1p(x); with LOG2E expected to be log2(e) this is
 * log2(1 + x).
 */
NPY_INPLACE @type@ npy_log2_1p@c@(@type@ x)
{
    return LOG2E*npy_log1p@c@(x);
}

/*
 * Returns expm1(LOGE2 * x); with LOGE2 expected to be ln(2) this is
 * 2**x - 1.
 */
NPY_INPLACE @type@ npy_exp2_m1@c@(@type@ x)
{
    return npy_expm1@c@(LOGE2*x);
}

/*
 * log(exp(x) + exp(y)), evaluated relative to the larger argument so that
 * only the exponential of a non-positive number is taken.
 */
NPY_INPLACE @type@ npy_logaddexp@c@(@type@ x, @type@ y)
{
    @type@ delta;

    /* equal arguments, including infinities of the same sign, without warnings */
    if (x == y) {
        return x + LOGE2;
    }

    delta = x - y;
    if (delta > 0) {
        return x + npy_log1p@c@(npy_exp@c@(-delta));
    }
    if (delta <= 0) {
        return y + npy_log1p@c@(npy_exp@c@(delta));
    }

    /* neither comparison held: delta is NaN, pass it on */
    return delta;
}

/*
 * log2(2**x + 2**y), evaluated relative to the larger argument so that only
 * a power of two with a non-positive exponent is taken.
 */
NPY_INPLACE @type@ npy_logaddexp2@c@(@type@ x, @type@ y)
{
    @type@ delta;

    /* equal arguments, including infinities of the same sign, without warnings */
    if (x == y) {
        return x + 1;
    }

    delta = x - y;
    if (delta > 0) {
        return x + npy_log2_1p@c@(npy_exp2@c@(-delta));
    }
    if (delta <= 0) {
        return y + npy_log2_1p@c@(npy_exp2@c@(delta));
    }

    /* neither comparison held: delta is NaN, pass it on */
    return delta;
}

/*
 * Wrapper function for remainder edge cases
 * Internally calls npy_divmod*
 */
NPY_INPLACE @type@
npy_remainder@c@(@type@ a, @type@ b)
{
    @type@ rem;

    if (NPY_UNLIKELY(!b)) {
        /*
         * b == 0 (and not NaN): plain fmod already gives the correct
         * result (always NaN). `divmod` may set additional FPE for the
         * division by zero creating an inf.
         */
        rem = npy_fmod@c@(a, b);
        return rem;
    }

    /* regular case: only the modulus of divmod is wanted */
    npy_divmod@c@(a, b, &rem);
    return rem;
}

/*
 * Wrapper function for floor-division edge cases
 * Internally calls npy_divmod*
 */
NPY_INPLACE @type@
npy_floor_divide@c@(@type@ a, @type@ b) {
    @type@ quot, rem;

    if (NPY_UNLIKELY(!b)) {
        /*
         * b == 0 (and not NaN): plain division already gives the correct
         * result (Inf or NaN). `divmod` may set additional FPE for the modulo
         * evaluating to NaN.
         */
        quot = a / b;
        return quot;
    }

    /* regular case: only the quotient of divmod is wanted */
    quot = npy_divmod@c@(a, b, &rem);
    return quot;
}

/*
 * Python version of divmod.
 *
 * The implementation is mostly copied from cpython 3.5.
 */
/*
 * Computes the floored quotient of a and b (returned) together with the
 * matching remainder (stored through `modulus`). A non-zero remainder is
 * adjusted so that its sign agrees with b; a zero remainder and a zero
 * quotient get their sign from copysign.
 */
NPY_INPLACE @type@
npy_divmod@c@(@type@ a, @type@ b, @type@ *modulus)
{
    @type@ div, mod, floordiv;

    mod = npy_fmod@c@(a, b);
    if (NPY_UNLIKELY(!b)) {
        /* b == 0 (not NaN): return result of fmod. For IEEE is nan */
        *modulus = mod;
        return a / b;
    }

    /* a - mod should be very nearly an integer multiple of b */
    div = (a - mod) / b;

    /* adjust fmod result to conform to Python convention of remainder */
    if (mod) {
        /* signs of b and mod differ: shift mod by b and compensate in div */
        if (isless(b, (@type@)0) != isless(mod, (@type@)0)) {
            mod += b;
            div -= 1.0@c@;
        }
    }
    else {
        /* if mod is zero ensure correct sign */
        mod = npy_copysign@c@(0, b);
    }

    /* snap quotient to nearest integral value */
    if (div) {
        floordiv = npy_floor@c@(div);
        /* div sits more than half way above its floor: round up instead */
        if (isgreater(div - floordiv, 0.5@c@))
            floordiv += 1.0@c@;
    }
    else {
        /* if div is zero ensure correct sign */
        floordiv = npy_copysign@c@(0, a/b);
    }

    *modulus = mod;
    return floordiv;
}

#undef LOGE2
#undef LOG2E
#undef RAD2DEG
#undef DEG2RAD
#undef NPY__FP_SFX
/**end repeat**/

/**begin repeat
 *
 * #type = npy_uint, npy_ulong, npy_ulonglong#
 * #c = u,ul,ull#
 */
/*
 * Greatest common divisor of two unsigned integers (Euclid's algorithm).
 * Returns the other argument when one of them is zero, so gcd(0, 0) is 0.
 */
NPY_INPLACE @type@
npy_gcd@c@(@type@ a, @type@ b)
{
    /* replace (a, b) by (b mod a, a) until a reaches zero */
    while (a != 0) {
        const @type@ previous = a;
        a = b % a;
        b = previous;
    }
    return b;
}

/*
 * Least common multiple of two unsigned integers, defined as 0 when both
 * arguments are zero.
 */
NPY_INPLACE @type@
npy_lcm@c@(@type@ a, @type@ b)
{
    const @type@ divisor = npy_gcd@c@(a, b);

    /* the gcd is zero only when a and b are both zero */
    if (divisor == 0) {
        return 0;
    }
    /* divide first, then multiply */
    return a / divisor * b;
}
/**end repeat**/

/**begin repeat
 *
 * #type = (npy_int, npy_long, npy_longlong)*2#
 * #utype = (npy_uint, npy_ulong, npy_ulonglong)*2#
 * #c = (,l,ll)*2#
 * #func=gcd*3,lcm*3#
 */
/*
 * Signed variant: applies the unsigned routine to the absolute values of a
 * and b and returns its result converted back to the signed type.
 *
 * The absolute value is formed in the unsigned type. Negating the signed
 * value directly would overflow (undefined behaviour) for the most negative
 * value, whereas unsigned subtraction wraps and yields its magnitude.
 */
NPY_INPLACE @type@
npy_@func@@c@(@type@ a, @type@ b)
{
    const @utype@ abs_a = a < 0 ? (@utype@)0 - (@utype@)a : (@utype@)a;
    const @utype@ abs_b = b < 0 ? (@utype@)0 - (@utype@)b : (@utype@)b;

    return npy_@func@u@c@(abs_a, abs_b);
}
/**end repeat**/

/* Unlike LCM and GCD, we need byte and short variants for the shift operators,
 * since the result is dependent on the width of the type
 */
/**begin repeat
 *
 * #type = byte, short, int, long, longlong#
 * #c = hh,h,,l,ll#
 */
/**begin repeat1
 *
 * #u         = u,#
 * #is_signed = 0,1#
 */
/*
 * Left shift of a by b bits. A shift count that is not smaller than the bit
 * width of the type gives 0.
 */
NPY_INPLACE npy_@u@@type@
npy_lshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
{
    const size_t width = sizeof(a) * CHAR_BIT;

    /* b is compared as size_t, so a negative count is out of range as well */
    if (NPY_UNLIKELY((size_t)b >= width)) {
        return 0;
    }
    return a << b;
}
/*
 * Right shift of a by b bits. A shift count that is not smaller than the bit
 * width of the type gives -1 for negative signed a and 0 otherwise.
 */
NPY_INPLACE npy_@u@@type@
npy_rshift@u@@c@(npy_@u@@type@ a, npy_@u@@type@ b)
{
    if (NPY_LIKELY((size_t)b < sizeof(a) * CHAR_BIT)) {
        return a >> b;
    }

    /* from here on the shift count is out of range */
#if @is_signed@
    if (a < 0) {
        /* preserve the sign bit */
        return (npy_@u@@type@)-1;
    }
#endif
    return 0;
}
/**end repeat1**/
/**end repeat**/

/*
 * Alias so that TO_BITS_LEN(__popcnt), which appends the bit width, also
 * resolves for 32-bit operands, where the intrinsic is spelled __popcnt.
 */
#define __popcnt32 __popcnt
/**begin repeat
 *
 * #type  = ubyte, ushort, uint, ulong, ulonglong#
 * #STYPE = BYTE,  SHORT,  INT,  LONG,  LONGLONG#
 * #c     = hh,    h,      ,     l,     ll#
 */
/*
 * TO_BITS_LEN(X) appends the bit width of the current type to X, e.g.
 * TO_BITS_LEN(MAGIC) selects the mask table of matching width. The `#if 0`
 * only opens the chain so that each generated width test can be an `#elif`.
 */
#undef TO_BITS_LEN
#if 0
/**begin repeat1
 * #len = 8, 16, 32, 64#
 */
#elif NPY_BITSOF_@STYPE@ == @len@
    #define TO_BITS_LEN(X) X##@len@
/**end repeat1**/
#endif


/*
 * Portable population count: number of set bits in a, computed with the
 * parallel bit-counting scheme described next to the MAGIC tables.
 */
NPY_INPLACE uint8_t
npy_popcount_parallel@c@(npy_@type@ a)
{
    const npy_@type@ pairs   = (npy_@type@) TO_BITS_LEN(MAGIC)[0];
    const npy_@type@ nibbles = (npy_@type@) TO_BITS_LEN(MAGIC)[1];
    const npy_@type@ bytes   = (npy_@type@) TO_BITS_LEN(MAGIC)[2];
    const npy_@type@ ones    = (npy_@type@) TO_BITS_LEN(MAGIC)[3];

    /* counts per 2-bit field */
    a = a - ((a >> 1) & pairs);
    /* counts per 4-bit field */
    a = (a & nibbles) + ((a >> 2) & nibbles);
    /* counts per byte */
    a = (a + (a >> 4)) & bytes;
    /* the multiply sums the byte counts into the top byte; shift it down */
    return (npy_@type@) (a * ones) >> ((NPY_SIZEOF_@STYPE@ - 1) * CHAR_BIT);
}

/*
 * Population count of an unsigned integer.
 *
 * Picks an implementation at compile time: the GCC/clang builtin for types
 * of at least 32 bits, the MSVC intrinsics for types of at least 16 bits on
 * non-ARM targets, and the portable parallel routine otherwise.
 */
NPY_INPLACE uint8_t
npy_popcountu@c@(npy_@type@ a)
{
/* use built-in popcount if present, else use our implementation */
#if (defined(__clang__) || defined(__GNUC__)) && NPY_BITSOF_@STYPE@ >= 32
    return __builtin_popcount@c@(a);
#elif defined(_MSC_VER) && NPY_BITSOF_@STYPE@ >= 16 && !defined(_M_ARM64) && !defined(_M_ARM)
    /* no builtin __popcnt64 for 32 bits */
    #if defined(_WIN64) || (defined(_WIN32) && NPY_BITSOF_@STYPE@ != 64)
        return TO_BITS_LEN(__popcnt)(a);
    /* split 64 bit number into two 32 bit ints and return sum of counts */
    #elif (defined(_WIN32) && NPY_BITSOF_@STYPE@ == 64)
        npy_uint32 left  = (npy_uint32) (a>>32);
        npy_uint32 right = (npy_uint32) a;
        return __popcnt32(left) + __popcnt32(right);
    #endif
#else
    return npy_popcount_parallel@c@(a);
#endif
}
/**end repeat**/

/**begin repeat
 *
 * #type = byte, short, int, long, longlong#
 * #c    = hh,   h,     ,    l,    ll#
 */
/*
 * Population count of a signed integer, defined as the number of set bits in
 * abs(a).
 *
 * The absolute value is formed in the unsigned type. Negating the signed
 * value directly would overflow (undefined behaviour) for the most negative
 * value of the int, long and long long variants, whereas unsigned
 * subtraction wraps and yields its magnitude.
 */
NPY_INPLACE uint8_t
npy_popcount@c@(npy_@type@ a)
{
    const npy_u@type@ bits = (npy_u@type@)a;

    /* Return popcount of abs(a) */
    return npy_popcountu@c@(a < 0 ? (npy_u@type@)(0 - bits) : bits);
}
/**end repeat**/

